{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "FashionMNIST.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9-final"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "lMov-LUSabxx"
      },
      "outputs": [],
      "source": [
        "# Convolutional Neural Networks for Classifying Fashion-MNIST Dataset using Ignite\n",
        "This is a tutorial on using Ignite to train neural network models, set up experiments and validate models.\n",
        "\n",
        "In this notebook, we will classify images using Convolutional Neural Networks.\n",
        "\n",
        "We will be using the [Fashion-MNIST dataset](https://github.com/zalandoresearch/fashion-mnist). Fashion-MNIST is a set of 28x28 greyscale images of clothes.\n",
        "\n",
        "![Fashion MNIST dataset](assets/fashion-mnist.png)\n",
        "\n",
        "Let's get started!"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "7PiU02-fabxz"
      },
      "outputs": [],
      "source": [
        "### Importing libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "A7LJdJfU6pra"
      },
      "outputs": [],
      "source": [
        "# !pip install pytorch-ignite"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "mGjzG-Obabx0"
      },
      "outputs": [],
      "source": [
        "General data-science libraries such as numpy, matplotlib and seaborn:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-FlbGAEe6prv"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "xp_4isvHabx5"
      },
      "outputs": [],
      "source": [
        "We import the `torch`, `nn` and `functional` modules to create our models, and `optim` for the optimizer.\n",
        "\n",
        "We also import `datasets` and `transforms` from torchvision for loading the dataset and applying transforms to the images in the dataset.\n",
        "\n",
        "We import `DataLoader` to create the train and validation loaders that feed data into our model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "E-OllAGT6pr-"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn, optim\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "Pj7oLY36abx8"
      },
      "outputs": [],
      "source": [
        "`Ignite` is a high-level library to help with training neural networks in PyTorch. It comes with an `Engine` to set up a training loop, various metrics, handlers and a helpful contrib section!\n",
        "\n",
        "Below we import the following:\n",
        "* **Engine**: Runs a given process_function over each batch of a dataset, emitting events as it goes. Here the engines are built with the `create_supervised_trainer` and `create_supervised_evaluator` helpers.\n",
        "* **Events**: Allows users to attach functions to an `Engine` so that they fire at a specific event, e.g. `EPOCH_COMPLETED`, `ITERATION_STARTED`, etc.\n",
        "* **Accuracy**: Metric to calculate accuracy over a dataset, for binary, multiclass and multilabel cases.\n",
        "* **Loss**: General metric that takes a loss function as a parameter and calculates the loss over a dataset.\n",
        "* **ConfusionMatrix**: Metric that accumulates a confusion matrix over a dataset.\n",
        "* **RunningAverage**: General metric that keeps a running average of a value during training; here it is attached to the trainer to smooth the batch loss.\n",
        "* **ModelCheckpoint**: Handler to checkpoint models.\n",
        "* **EarlyStopping**: Handler to stop training based on a score function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YodbjXPi6psK"
      },
      "outputs": [],
      "source": [
        "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
        "from ignite.metrics import Accuracy, Loss, RunningAverage, ConfusionMatrix\n",
        "from ignite.handlers import ModelCheckpoint, EarlyStopping"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "vO53-X98abx-"
      },
      "outputs": [],
      "source": [
        "The code below first sets up a transform using `torchvision.transforms` to convert images to PyTorch tensors and normalize them.\n",
        "\n",
        "Next, we use `torchvision.datasets` to download the Fashion-MNIST dataset and apply the transform defined above.\n",
        "\n",
        "* `trainset` contains the training data.\n",
        "* `validationset` contains the validation data (the Fashion-MNIST test split, loaded with `train=False`).\n",
        "\n",
        "Finally, we use PyTorch's `DataLoader` to create data loaders for the train and validation sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Y2dqyxnr6psU"
      },
      "outputs": [],
      "source": [
        "# transform to normalize the data\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Normalize((0.5,), (0.5,))])\n",
        "\n",
        "# Download and load the training data\n",
        "trainset = datasets.FashionMNIST('./data', download=True, train=True, transform=transform)\n",
        "train_loader = DataLoader(trainset, batch_size=64, shuffle=True)\n",
        "\n",
        "# Download and load the test data\n",
        "validationset = datasets.FashionMNIST('./data', download=True, train=False, transform=transform)\n",
        "val_loader = DataLoader(validationset, batch_size=64, shuffle=True)"
      ]
    },
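    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of what the normalization step above does: `transforms.ToTensor()` scales 8-bit pixel values to the range [0, 1], and `transforms.Normalize((0.5,), (0.5,))` then computes `(x - mean) / std` with `mean = std = 0.5`. The cell below is a minimal sketch of that arithmetic in plain Python, without torchvision. `to_unit_range` and `normalize` are hypothetical helper names introduced only for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative helpers (hypothetical names, not part of torchvision)\n",
        "def to_unit_range(raw_pixel):\n",
        "    # ToTensor scales 8-bit values from [0, 255] to [0, 1]\n",
        "    return raw_pixel / 255\n",
        "\n",
        "def normalize(value, mean=0.5, std=0.5):\n",
        "    # Normalize((0.5,), (0.5,)) computes (value - mean) / std\n",
        "    return (value - mean) / std\n",
        "\n",
        "# the darkest and brightest pixels end up at -1 and 1\n",
        "for raw_pixel in (0, 255):\n",
        "    print(raw_pixel, '->', normalize(to_unit_range(raw_pixel)))\n",
        "# 0 -> -1.0\n",
        "# 255 -> 1.0"
      ]
    },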
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "Cvzm-Oe1abyB"
      },
      "outputs": [],
      "source": [
        "### CNN Model"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "uf4AVJ44abyC"
      },
      "outputs": [],
      "source": [
        "Explanation of the model architecture:\n",
        "\n",
        "* [Convolutional layers](https://pytorch.org/docs/stable/nn.html#conv2d): each convolutional layer learns a set of kernels that are convolved with the layer input to produce a tensor of outputs.\n",
        "* [Maxpooling layers](https://pytorch.org/docs/stable/nn.html#maxpool2d): each max-pooling layer downsamples its input, keeping only the largest activation in each pooling window.\n",
        "* The usual [Linear](https://pytorch.org/docs/stable/nn.html#linear) layers to produce a 10-dim output, with [Dropout](https://pytorch.org/docs/stable/nn.html#dropout2d) to reduce overfitting.\n",
        "* We use the [ReLU](https://pytorch.org/docs/stable/nn.html#id27) non-linearity in the convolutional blocks and [logsoftmax](https://pytorch.org/docs/stable/nn.html#log-softmax) at the last layer because we are going to use the [NLL loss](https://pytorch.org/docs/stable/nn.html#nllloss).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "AST_DtTC6psh"
      },
      "outputs": [],
      "source": [
        "class CNN(nn.Module):\n",
        "    \n",
        "    def __init__(self):\n",
        "        super(CNN, self).__init__()\n",
        "        \n",
        "        # 1x28x28 input -> 32x14x14\n",
        "        self.convlayer1 = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, 3, padding=1),\n",
        "            nn.BatchNorm2d(32),\n",
        "            nn.ReLU(),\n",
        "            nn.MaxPool2d(kernel_size=2, stride=2)\n",
        "        )\n",
        "        \n",
        "        # 32x14x14 -> 64x6x6\n",
        "        self.convlayer2 = nn.Sequential(\n",
        "            nn.Conv2d(32, 64, 3),\n",
        "            nn.BatchNorm2d(64),\n",
        "            nn.ReLU(),\n",
        "            nn.MaxPool2d(2)\n",
        "        )\n",
        "        \n",
        "        self.fc1 = nn.Linear(64*6*6, 600)\n",
        "        # element-wise dropout, since fc1 outputs a flat (batch, features) tensor\n",
        "        self.drop = nn.Dropout(0.25)\n",
        "        self.fc2 = nn.Linear(600, 120)\n",
        "        self.fc3 = nn.Linear(120, 10)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        x = self.convlayer1(x)\n",
        "        x = self.convlayer2(x)\n",
        "        x = x.view(-1,64*6*6)\n",
        "        x = self.fc1(x)\n",
        "        x = self.drop(x)\n",
        "        x = self.fc2(x)\n",
        "        x = self.fc3(x)\n",
        "        \n",
        "        return F.log_softmax(x,dim=1)"
      ]
    },
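    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `64*6*6` input size of `fc1` follows from the layer settings in the model above. The cell below is a small sketch of that arithmetic in plain Python, assuming square inputs and the standard output-size formula `floor((n + 2p - k) / s) + 1`. `conv_out` is a hypothetical helper name introduced here for illustration, not a PyTorch function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative helper (hypothetical name, not a PyTorch function)\n",
        "def conv_out(size, kernel, stride=1, padding=0):\n",
        "    # output size of a convolution or pooling layer along one spatial dimension\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "size = 28                                    # Fashion-MNIST images are 28x28\n",
        "size = conv_out(size, kernel=3, padding=1)   # convlayer1 convolution -> 28\n",
        "size = conv_out(size, kernel=2, stride=2)    # convlayer1 max-pooling -> 14\n",
        "size = conv_out(size, kernel=3)              # convlayer2 convolution, no padding -> 12\n",
        "size = conv_out(size, kernel=2, stride=2)    # convlayer2 max-pooling -> 6\n",
        "flattened = 64 * size * size                 # convlayer2 produces 64 channels\n",
        "print(size, flattened)  # 6 2304"
      ]
    },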
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "oqvY1QLBabyE"
      },
      "outputs": [],
      "source": [
        "### Creating Model, Optimizer and Loss"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "8z2zUS5zabyF"
      },
      "outputs": [],
      "source": [
        "Below we create an instance of the CNN model. The model is moved to the available device (GPU if present, otherwise CPU), and then a negative log likelihood loss and an Adam optimizer with a learning rate of 0.001 are set up."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "JXyVBBYw6pst"
      },
      "outputs": [],
      "source": [
        "# creating model,and defining optimizer and loss\n",
        "model = CNN()\n",
        "# moving model to gpu if available\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "model.to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "criterion = nn.NLLLoss()"
      ]
    },
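    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pairing `log_softmax` in the model with `nn.NLLLoss` amounts to the usual cross-entropy loss: the loss for one sample is the negative log-probability assigned to the true class. The cell below is a minimal plain-Python sketch of that computation for a single sample. `log_softmax` and `nll_loss` here are illustrative re-implementations rather than the torch functions, and the logits are made-up numbers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Plain-Python versions for a single sample (illustrative, not the torch implementation)\n",
        "def log_softmax(logits):\n",
        "    # subtract the maximum before exponentiating for numerical stability\n",
        "    top = max(logits)\n",
        "    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))\n",
        "    return [z - log_sum_exp for z in logits]\n",
        "\n",
        "def nll_loss(log_probs, target):\n",
        "    # negative log likelihood of the true class\n",
        "    return -log_probs[target]\n",
        "\n",
        "logits = [2.0, 0.5, -1.0]    # made-up scores for 3 classes\n",
        "log_probs = log_softmax(logits)\n",
        "total_prob = sum(math.exp(lp) for lp in log_probs)\n",
        "print(round(total_prob, 6))  # 1.0, the exponentiated outputs form a probability distribution\n",
        "# the loss is smaller when the true class is the one with the highest score\n",
        "print(nll_loss(log_probs, 0) < nll_loss(log_probs, 2))  # True"
      ]
    },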
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "o69S4bEvabyI"
      },
      "outputs": [],
      "source": [
        "### Training and Evaluating using Ignite"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "PALFpjYoabyJ"
      },
      "outputs": [],
      "source": [
        "### Instantiating Training and Evaluating Engines\n",
        "\n",
        "Below we create 3 engines: a trainer, an evaluator for the training set and an evaluator for the validation set, using `create_supervised_trainer` and `create_supervised_evaluator` with the required arguments.\n",
        "\n",
        "From `ignite.metrics` we take the metrics we want to compute for the model, namely `Accuracy`, `ConfusionMatrix` and `Loss`, and pass them to the `evaluator` engines, which accumulate them over all batches of an evaluation run.\n",
        "\n",
        "* `training_history`: stores the training loss and accuracy for each epoch\n",
        "* `validation_history`: stores the validation loss and accuracy for each epoch\n",
        "* `last_epoch`: a list that gets one entry per completed epoch, so its length gives the last epoch the model was trained to\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "a7EaSDMW6ps5"
      },
      "outputs": [],
      "source": [
        "# defining the number of epochs\n",
        "epochs = 12\n",
        "# creating trainer,evaluator\n",
        "trainer = create_supervised_trainer(model, optimizer, criterion, device=device)\n",
        "metrics = {\n",
        "    'accuracy':Accuracy(),\n",
        "    'nll':Loss(criterion),\n",
        "    'cm':ConfusionMatrix(num_classes=10)\n",
        "}\n",
        "train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)\n",
        "val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)\n",
        "training_history = {'accuracy':[],'loss':[]}\n",
        "validation_history = {'accuracy':[],'loss':[]}\n",
        "last_epoch = []"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "aonym06RabyL"
      },
      "outputs": [],
      "source": [
        "### Metrics - RunningAverage\n",
        "\n",
        "To start, we will attach a metric of `RunningAverage` to track a running average of the scalar loss output for each batch. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4LXaAfga6ptD"
      },
      "outputs": [],
      "source": [
        "RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "2vDAID6BabyP"
      },
      "outputs": [],
      "source": [
        "### EarlyStopping - Tracking Validation Loss\n",
        "\n",
        "Now we will set up an `EarlyStopping` handler for this training process. `EarlyStopping` requires a `score_function` that lets the user define the criterion for stopping training. A higher score is treated as better, so we return the negative validation loss. In this case, if the loss on the validation set does not decrease for 10 consecutive epochs, the training process will stop early. Since the `EarlyStopping` handler relies on the validation loss, it's attached to the `val_evaluator`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0mcIq0lc6ptN"
      },
      "outputs": [],
      "source": [
        "def score_function(engine):\n",
        "    val_loss = engine.state.metrics['nll']\n",
        "    return -val_loss\n",
        "\n",
        "handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)\n",
        "val_evaluator.add_event_handler(Events.COMPLETED, handler)"
      ]
    },
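    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of `patience` concrete, the cell below is a simplified plain-Python sketch of patience-based early stopping. It is an illustration under stated assumptions, not Ignite's implementation: `epochs_until_stop` is a hypothetical function, the validation losses are made-up numbers, and any improvement in the score (the negative loss, as in `score_function` above) resets the counter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Simplified sketch of patience-based early stopping (hypothetical helper, not Ignite's code)\n",
        "def epochs_until_stop(val_losses, patience):\n",
        "    # returns the 1-based epoch at which training would stop, or None if it never stops\n",
        "    best_score = None\n",
        "    epochs_without_improvement = 0\n",
        "    for epoch, loss in enumerate(val_losses, start=1):\n",
        "        score = -loss  # higher score means better, so the loss is negated\n",
        "        if best_score is None or score > best_score:\n",
        "            best_score = score\n",
        "            epochs_without_improvement = 0\n",
        "        else:\n",
        "            epochs_without_improvement += 1\n",
        "            if epochs_without_improvement >= patience:\n",
        "                return epoch\n",
        "    return None\n",
        "\n",
        "losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38]   # made-up validation losses\n",
        "print(epochs_until_stop(losses, patience=3))    # 6\n",
        "print(epochs_until_stop(losses, patience=10))   # None"
      ]
    },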
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "GZzf0wtLabyR"
      },
      "outputs": [],
      "source": [
        "### Attaching Custom Functions to Engine at specific Events\n",
        "\n",
        "Below you will see how to define your own custom functions and attach them to various `Events` of the training process.\n",
        "\n",
        "The two functions below do similar jobs: each runs an evaluator on a dataset, records the resulting accuracy and loss, and prints them. One uses the training evaluator and dataset, the other the validation ones. They also differ in how they are attached to the trainer engine.\n",
        "\n",
        "The first method uses a decorator: `@trainer.on(Events.EPOCH_COMPLETED)` means that the decorated function will be attached to the trainer and called at the end of each epoch.\n",
        "\n",
        "The second method uses the `add_event_handler` method of the trainer: `trainer.add_event_handler(Events.EPOCH_COMPLETED, custom_function)`. This achieves the same result as the decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UFQNaZDx6ptV"
      },
      "outputs": [],
      "source": [
        "@trainer.on(Events.EPOCH_COMPLETED)\n",
        "def log_training_results(trainer):\n",
        "    train_evaluator.run(train_loader)\n",
        "    metrics = train_evaluator.state.metrics\n",
        "    accuracy = metrics['accuracy']*100\n",
        "    loss = metrics['nll']\n",
        "    last_epoch.append(0)\n",
        "    training_history['accuracy'].append(accuracy)\n",
        "    training_history['loss'].append(loss)\n",
        "    print(\"Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\"\n",
        "          .format(trainer.state.epoch, accuracy, loss))\n",
        "\n",
        "def log_validation_results(trainer):\n",
        "    val_evaluator.run(val_loader)\n",
        "    metrics = val_evaluator.state.metrics\n",
        "    accuracy = metrics['accuracy']*100\n",
        "    loss = metrics['nll']\n",
        "    validation_history['accuracy'].append(accuracy)\n",
        "    validation_history['loss'].append(loss)\n",
        "    print(\"Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}\"\n",
        "          .format(trainer.state.epoch, accuracy, loss))\n",
        "    \n",
        "trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)    "
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "z4ZHDOs1abyT"
      },
      "outputs": [],
      "source": [
        "### Confusion Matrix"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "j5NOsTQxabyU"
      },
      "outputs": [],
      "source": [
        "A confusion matrix gives us a better idea of what our classification model is getting right and what types of errors it is making.\n",
        "\n",
        "We visualize the confusion matrix using `seaborn.heatmap`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kVrVyE8X6pth"
      },
      "outputs": [],
      "source": [
        "@trainer.on(Events.COMPLETED)\n",
        "def log_confusion_matrix(trainer):\n",
        "    val_evaluator.run(val_loader)\n",
        "    metrics = val_evaluator.state.metrics\n",
        "    cm = metrics['cm']\n",
        "    # move the tensor to the CPU before converting it to a numpy array\n",
        "    cm = cm.cpu().numpy()\n",
        "    cm = cm.astype(int)\n",
        "    classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']\n",
        "    fig, ax = plt.subplots(figsize=(10,10))\n",
        "    sns.heatmap(cm, annot=True, ax=ax, fmt=\"d\")\n",
        "    # labels, title and ticks\n",
        "    ax.set_xlabel('Predicted labels')\n",
        "    ax.set_ylabel('True labels') \n",
        "    ax.set_title('Confusion Matrix') \n",
        "    ax.xaxis.set_ticklabels(classes,rotation=90)\n",
        "    ax.yaxis.set_ticklabels(classes,rotation=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "HX-NImoSabyW"
      },
      "outputs": [],
      "source": [
        "### ModelCheckpoint\n",
        "\n",
        "Lastly, we want to checkpoint this model. It's important to do so, as training processes can be time consuming and if for some reason something goes wrong during training, a model checkpoint can be helpful to restart training from the point of failure.\n",
        "\n",
        "Below we will use Ignite's `ModelCheckpoint` handler to checkpoint models at the end of each epoch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "qGVhrmoZ6ptp"
      },
      "outputs": [],
      "source": [
        "checkpointer = ModelCheckpoint('./saved_models', 'fashionMNIST', n_saved=2, create_dir=True, require_empty=False)\n",
        "trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'fashionMNIST': model})"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "4YK6lGPHabyY"
      },
      "outputs": [],
      "source": [
        "### Run Engine\n",
        "\n",
        "Next, we will run the trainer for 12 epochs and monitor the results. The custom functions defined above print the `loss` and `accuracy` after each epoch.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lW6m6Stp6pty"
      },
      "outputs": [],
      "source": [
        "trainer.run(train_loader, max_epochs=epochs)"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "hTwbdtI_abyg"
      },
      "outputs": [],
      "source": [
        "### Plotting the loss and accuracy\n",
        "Next, we will plot the loss and accuracy stored in the `training_history` and `validation_history` dictionaries to see how they change with each epoch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SptDgBR66pt4"
      },
      "outputs": [],
      "source": [
        "plt.plot(training_history['accuracy'],label=\"Training Accuracy\")\n",
        "plt.plot(validation_history['accuracy'],label=\"Validation Accuracy\")\n",
        "plt.xlabel('No. of Epochs')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.legend(frameon=False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "g4iomcl36pt-"
      },
      "outputs": [],
      "source": [
        "plt.plot(training_history['loss'],label=\"Training Loss\")\n",
        "plt.plot(validation_history['loss'],label=\"Validation Loss\")\n",
        "plt.xlabel('No. of Epochs')\n",
        "plt.ylabel('Loss')\n",
        "plt.legend(frameon=False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "NW9ze2HCabyp"
      },
      "outputs": [],
      "source": [
        "### Loading the saved model from the disk\n",
        "We load the most recent checkpoint of the PyTorch model from disk so that it can be used for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SU4y96CI6puH"
      },
      "outputs": [],
      "source": [
        "# loading the saved model\n",
        "def fetch_last_checkpoint_model_filename(model_save_path):\n",
        "    import os\n",
        "    checkpoint_files = os.listdir(model_save_path)\n",
        "    # checkpoints are saved as '.pt' or '.pth' files depending on the Ignite version\n",
        "    checkpoint_files = [f for f in checkpoint_files if f.endswith(('.pt', '.pth'))]\n",
        "    checkpoint_iter = [\n",
        "        int(x.split('_')[2].split('.')[0])\n",
        "        for x in checkpoint_files]\n",
        "    last_idx = np.array(checkpoint_iter).argmax()\n",
        "    return os.path.join(model_save_path, checkpoint_files[last_idx])\n",
        "\n",
        "model.load_state_dict(torch.load(fetch_last_checkpoint_model_filename('./saved_models')))\n",
        "print(\"Model Loaded\")"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "QNRWgV0Kabys"
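      },
      "outputs": [],
      "source": [
        "### Running inference with the model\n",
        "The code below runs inference with the model and visualizes the results.\n",
        "\n",
        "Here we take one batch from the `val_loader`, select the class with the highest probability for each of the first 10 images, and compare it with the actual class. Each plot title shows the predicted class followed by the true class in parentheses, in green when they match and in red otherwise."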
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "uWojxHNy6puN"
      },
      "outputs": [],
      "source": [
        "# classes of fashion mnist dataset\n",
        "classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']\n",
        "# creating iterator for iterating the dataset\n",
        "dataiter = iter(val_loader)\n",
        "images, labels = next(dataiter)\n",
        "images_arr = []\n",
        "labels_arr = []\n",
        "pred_arr = []\n",
        "# moving model to cpu and switching to evaluation mode for inference\n",
        "# (eval mode disables dropout and makes BatchNorm use its running statistics)\n",
        "model.to(\"cpu\")\n",
        "model.eval()\n",
        "# iterating on the dataset to predict the output\n",
        "for i in range(0,10):\n",
        "    images_arr.append(images[i].unsqueeze(0))\n",
        "    labels_arr.append(labels[i].item())\n",
        "    # no gradients are needed for inference\n",
        "    with torch.no_grad():\n",
        "        ps = torch.exp(model(images_arr[i]))\n",
        "    ps = ps.numpy().squeeze()\n",
        "    pred_arr.append(np.argmax(ps))\n",
        "# plotting the results\n",
        "fig = plt.figure(figsize=(25,4))\n",
        "for i in range(10):\n",
        "    # subplot grid sizes must be integers\n",
        "    ax = fig.add_subplot(2, 10, i+1, xticks=[], yticks=[])\n",
        "    ax.imshow(images_arr[i].squeeze().numpy())\n",
        "    ax.set_title(\"{} ({})\".format(classes[pred_arr[i]], classes[labels_arr[i]]),\n",
        "                 color=(\"green\" if pred_arr[i]==labels_arr[i] else \"red\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "execution_count": null,
      "metadata": {
        "colab_type": "text",
        "id": "cWXOEpQ4abyv"
      },
      "outputs": [],
      "source": [
        "### References\n",
        "* [PyTorch Ignite Text CNN example notebook](https://github.com/pytorch/ignite/blob/master/examples/notebooks/TextCNN.ipynb)\n",
        "* [PyTorch Ignite MNIST example](https://github.com/pytorch/ignite/blob/master/examples/mnist/mnist.py)"
      ]
    }
  ]
}